import numpy as np
from pyxtal.interface.mushybox import mushybox
import ase.units
from ase.calculators.calculator import Calculator
from lammps import lammps
import os
from pkg_resources import resource_filename
from ase.data import chemical_symbols

class LAMMPSlib(Calculator):
    """
    ASE-style calculator that evaluates structures through a LAMMPS instance.

    For every new structure a LAMMPS data file and an input script are
    written into `folder`, the script is executed with `lmp.file`, and the
    energy, stress, forces and positions are read back from LAMMPS.

    Args:
        lmp: object providing `file`, `command`, `gather_atoms` and
            `extract_variable` (e.g. a `lammps.lammps` instance)
        lmpcmds: iterable of LAMMPS command strings appended to the
            generated input script
        folder: directory receiving `data.lammps` and `in.lammps`
        mol: if True, use the water data-file writer and `units electron`;
            otherwise the plain atomic writer and `units metal`
        *args, **kwargs: passed on to `Calculator.__init__`
    """

    def __init__(self, lmp, lmpcmds, folder='tmp', mol=False, *args, **kwargs):
        Calculator.__init__(self, *args, **kwargs)
        self.lmp = lmp
        self.molecule = mol
        self.folder = folder
        os.makedirs(folder, exist_ok=True)
        self.lammps_data = folder + '/data.lammps'
        self.lammps_in = folder + '/in.lammps'
        self.ffpath = resource_filename("pyxtal", "potentials")
        # Keep a private copy so later changes to the caller's list are ignored
        self.paras = list(lmpcmds)

    def calculate(self, atoms):
        """
        Write the LAMMPS files for `atoms`, run them and store the results.

        Sets `self.energy`, `self.stress` (6 components ordered
        xx, yy, zz, yz, xz, xy, sign-flipped), `self.forces` and `self.atoms`.
        The positions of `atoms` are overwritten with the positions gathered
        from LAMMPS after the run.

        Args:
            atoms: structure exposing the ASE `Atoms` interface used by the
                data-file writers (`numbers`, `get_cell`, `get_positions`)
        """
        if self.molecule:
            self.write_lammps_data_water(atoms)
        else:
            self.write_lammps_data(atoms)
        self.write_lammps_in()
        self.lmp.file(self.lammps_in)

        # Define LAMMPS variables mirroring the thermo/atom keywords so that
        # they can be read back with extract_variable
        for style, names in (('equal', ('pxx', 'pyy', 'pzz', 'pxy', 'pxz', 'pyz')),
                             ('atom', ('fx', 'fy', 'fz')),
                             ('equal', ('pe',))):
            for name in names:
                self.lmp.command('variable {0} {1} {0}'.format(name, style))

        # NOTE(review): the 0.529 / 27.2114 / 1e+9 factors below are applied
        # for both mol=True and mol=False, although only the water writer
        # scales lengths by 1/0.529 -- confirm the intended units for the
        # non-molecular (units metal) path.
        pos = np.array(self.lmp.gather_atoms("x", 1, 3)).reshape(-1, 3)
        pos = pos*0.529

        self.energy = self.lmp.extract_variable('pe', None, 0)/27.2114

        stress_vars = ['pxx', 'pyy', 'pzz', 'pyz', 'pxz', 'pxy']
        stress = np.array([self.lmp.extract_variable(var, None, 0)
                           for var in stress_vars])/1e+9
        self.stress = -stress

        self.forces = (np.array(self.lmp.gather_atoms("f", 1, 3)).reshape(-1, 3)
                       /27.2114*0.529)
        atoms.positions = pos.copy()
        # Cached copy used by `update` to detect whether a new run is needed
        self.atoms = atoms.copy()

    def write_lammps_in(self):
        """
        Write the LAMMPS input script to `self.lammps_in`.

        The script clears LAMMPS, selects the unit system, appends the user
        commands and finishes with `run 0`. A `read_data` line pointing to
        `self.lammps_data` is inserted only when none of the user commands
        contains `read_data`.
        """
        units = 'electron' if self.molecule else 'metal'
        has_read_data = any('read_data' in para for para in self.paras)
        with open(self.lammps_in, 'w') as fh:
            fh.write('clear\n')
            fh.write('units {:s}\n'.format(units))
            fh.write('boundary p p p\n')
            fh.write('atom_modify sort 0 0.0\n')

            fh.write('\n### interactions\n')
            if not has_read_data:
                fh.write('read_data {:s}\n'.format(self.lammps_data))
            for para in self.paras:
                fh.write("{:s}\n".format(para))
            fh.write('thermo_style custom pe pxx\n')
            fh.write('thermo_modify flush yes\n')
            fh.write('thermo 1\n')
            fh.write('run 0\n')

    def write_lammps_data(self, atoms):
        """
        Write an atomic LAMMPS data file to `self.lammps_data`.

        Atom types are numbered 1..n following the sorted unique atomic
        numbers. Only the diagonal of the cell is written as box bounds.

        Args:
            atoms: structure providing `numbers`, `get_cell()` and
                `get_positions()`
        """
        n_types = np.unique(atoms.numbers)
        # Integer dtype is required: the type column is written with '{:d}',
        # which rejects numpy floats
        lmp_types = np.zeros(len(atoms), dtype=int)
        for i, typ in enumerate(n_types):
            lmp_types[atoms.numbers == typ] = i+1

        with open(self.lammps_data, 'w') as fh:
            comment = 'lammpslib autogenerated data file'
            fh.write(comment.strip() + '\n\n')
            fh.write('{0} atoms\n'.format(len(atoms)))
            fh.write('{0} atom types\n'.format(len(n_types)))
            cell = atoms.get_cell()
            fh.write('\n')
            fh.write('{0:16.8e} {1:16.8e} xlo xhi\n'.format(0.0, cell[0, 0]))
            fh.write('{0:16.8e} {1:16.8e} ylo yhi\n'.format(0.0, cell[1, 1]))
            fh.write('{0:16.8e} {1:16.8e} zlo zhi\n'.format(0.0, cell[2, 2]))
            fh.write('\n\nAtoms \n\n')
            for i, (typ, pos) in enumerate(
                    zip(lmp_types, atoms.get_positions())):
                fh.write('{0:d} {1:d} {2:16.8e} {3:16.8e} {4:16.8e}\n'
                         .format(i + 1, typ, pos[0], pos[1], pos[2]))

    def write_lammps_data_water(self, atoms):
        """
        Write a LAMMPS data file for a water model to `self.lammps_data`.

        Atoms are taken in consecutive groups of three per molecule; the
        bond and angle lists treat the first atom of each group as the
        central atom. Oxygen (Z=8) is type 1 and hydrogen (Z=1) is type 2.
        Cell lengths and positions are divided by 0.529 before writing.

        Args:
            atoms: structure providing `numbers`, `get_cell()` and
                `get_positions()`
        """
        N_atom = len(atoms)
        N_mol = N_atom // 3
        N_bond = N_mol * 2
        N_angle = N_mol

        lmp_types = np.zeros(N_atom, dtype=int)
        lmp_types[atoms.numbers == 1] = 2
        lmp_types[atoms.numbers == 8] = 1

        # Molecule id (1-based) of every atom
        mol_types = np.zeros(N_atom, dtype=int)
        for i in range(N_mol):
            mol_types[i*3:(i+1)*3] = i+1

        with open(self.lammps_data, 'w') as fh:
            comment = 'lammpslib autogenerated data file'
            fh.write(comment.strip() + '\n\n')
            fh.write('{0} atoms\n'.format(N_atom))
            fh.write('{0} bonds\n'.format(N_bond))
            fh.write('{0} angles\n'.format(N_angle))

            fh.write('\n2 atom types\n')
            fh.write('1 bond types\n')
            fh.write('1 angle types\n')

            cell = atoms.get_cell()/0.529
            fh.write('\n')
            fh.write('{0:16.8e} {1:16.8e} xlo xhi\n'.format(0.0, cell[0, 0]))
            fh.write('{0:16.8e} {1:16.8e} ylo yhi\n'.format(0.0, cell[1, 1]))
            fh.write('{0:16.8e} {1:16.8e} zlo zhi\n'.format(0.0, cell[2, 2]))

            fh.write('\n\nMasses \n\n')
            fh.write('  1 15.9994\n')
            fh.write('  2  1.0000\n')

            fh.write('\n\nBond Coeffs \n\n')
            fh.write(' 1    1.78    0.2708585 -0.327738785 0.231328959\n')

            fh.write('\n\nAngle Coeffs \n\n')
            fh.write('  1    0.0700  107.400000')
            fh.write('\n\nAtoms \n\n')
            # Columns: atom id, molecule id, atom type, charge, x, y, z
            for i, (typ, mtyp, pos) in enumerate(
                    zip(lmp_types, mol_types, atoms.get_positions()/0.529)):
                charge = 0.5564 if typ == 2 else -1.1128
                fh.write('{0:4d} {1:4d} {2:4d} {3:8.4f} {4:16.8f} {5:16.8f} {6:16.8f}\n'
                         .format(i + 1, mtyp, typ, charge, pos[0], pos[1], pos[2]))

            fh.write('\nBonds \n\n')
            for i in range(N_mol):
                fh.write('{:4d} {:4d} {:4d} {:4d}\n'.format(i*2+1, 1, i*3+1, i*3+2))
                fh.write('{:4d} {:4d} {:4d} {:4d}\n'.format(i*2+2, 1, i*3+1, i*3+3))

            fh.write('\nAngles \n\n')
            for i in range(N_angle):
                fh.write('{:4d} {:4d} {:4d} {:4d} {:4d}\n'.format(i+1, 1, i*3+2, i*3+1, i*3+3))

    def update(self, atoms):
        """
        Run `calculate` unless `atoms` equals the cached structure.
        """
        if not hasattr(self, 'atoms') or self.atoms != atoms:
            self.calculate(atoms)

    def get_potential_energy(self, atoms):
        """
        Return the energy stored by the latest calculation for `atoms`.
        """
        self.update(atoms)
        return self.energy

    def get_forces(self, atoms):
        """
        Return a copy of the forces stored by the latest calculation.
        """
        self.update(atoms)
        return self.forces.copy()

    def get_stress(self, atoms):
        """
        Return a copy of the 6-component stress stored by the latest
        calculation.
        """
        self.update(atoms)
        return self.stress.copy()

def run_lammpslib(pbcgb, lmp, parameters, path, method='opt', temp=300, steps=10000):
    """
    Evaluate a structure with LAMMPS after appending a run protocol.

    Args:
        pbcgb: structure object; must provide `set_calculator` and
            `get_potential_energy`
        lmp: LAMMPS instance handed to `LAMMPSlib`
        parameters: list of LAMMPS command strings (left unmodified)
        path: forwarded to `LAMMPSlib` as the keyword `path`
        method: 'md' (NVT run followed by a cg minimization), 'opt'
            (cg minimization only); anything else appends `run 0`
        temp: thermostat temperature used for method='md'
        steps: number of MD steps used for method='md'

    Returns:
        tuple of (pbcgb, energy)
    """
    # Work on a copy so the caller's command list is not extended in place
    cmds = list(parameters)
    if method == 'md':
        cmds += [
                 "fix 1 all nvt temp " + str(temp) + ' '  + str(temp) + " 0.05",
                 "timestep  0.001",
                 "thermo 1000",
                 "run " + str(steps),
                 "reset_timestep 0",
                 "min_style cg",
                 "minimize 1e-15 1e-15 10000 10000",
                 "thermo 0",
                ]
    elif method == 'opt':
        cmds += [
                 "min_style cg",
                 "minimize 1e-15 1e-15 10000 10000",
                ]
    else:
        cmds += ['run 0']

    # NOTE(review): `path` and `log_file` are not named parameters of
    # LAMMPSlib.__init__ and end up in the Calculator keyword arguments;
    # confirm whether `path` was meant to be `folder`.
    calc = LAMMPSlib(lmp=lmp, lmpcmds=cmds, log_file='lammps.log', path=path)
    pbcgb.set_calculator(calc)
    Eng = pbcgb.get_potential_energy()
    return pbcgb, Eng

def optimize_lammpslib(pbcgb, lmp, parameters, path, method='FIRE', fmax=0.1):
    """
    Relax a structure with an ASE optimizer using LAMMPS for the energetics.

    The structure is wrapped in `mushybox` with `fixstrain[2][2] = 1` and
    optimized for at most 500 steps.

    Args:
        pbcgb: structure object; must provide `set_calculator`
        lmp: LAMMPS instance handed to `LAMMPSlib`
        parameters: list of LAMMPS command strings
        path: forwarded to `LAMMPSlib` as the keyword `path`
        method: 'FIRE' selects the FIRE optimizer, anything else BFGS
        fmax: force convergence threshold passed to the optimizer

    Returns:
        the structure `pbcgb` after the optimizer run
    """
    # Imported here because the module level only imports the optimizers
    # inside the `__main__` guard
    from ase.optimize import BFGS
    from ase.optimize.fire import FIRE

    calc = LAMMPSlib(lmp=lmp, lmpcmds=parameters, log_file='lammps.log', path=path)
    pbcgb.set_calculator(calc)

    fixstrain = np.zeros((3, 3))
    fixstrain[2][2] = 1
    box = mushybox(pbcgb, fixstrain=fixstrain)

    optimizer = FIRE if method == 'FIRE' else BFGS
    dyn = optimizer(box)
    dyn.run(fmax=fmax, steps=500)
    return pbcgb


if __name__ == '__main__':
    # Demo: generate random water crystals, relax each one with FIRE
    # followed by BFGS, and log space group, energy, volume and residual
    # force to results.log.
    from pyxtal.crystal import Lattice
    from pyxtal.molecular_crystal import molecular_crystal
    from spglib import get_symmetry_dataset
    from ase.optimize.fire import FIRE
    from ase import Atoms
    from ase.optimize import BFGS
    from ase.build import sort
    import logging

    lammps_name=''
    comm=None
    log_file='lammps.log'
    cmd_args = ['-echo', 'log', '-log', log_file,
                '-screen', 'none', '-nocite']
    lmp = lammps(lammps_name, cmd_args, comm)

    logging.basicConfig(format='%(asctime)s :: %(message)s', filename='results.log', level=logging.INFO)


    parameters = ["atom_style full",
                  "pair_style      lj/cut/tip4p/long 1 2 1 1 0.278072379 17.007",
                  "bond_style      class2 ",
                  "angle_style     harmonic",
                  "kspace_style pppm/tip4p 0.0001",
                  "read_data tmp/data.lammps",
                  "pair_coeff  * * 0 0",
                  "pair_coeff  1  1  0.000295147 5.96946",
                  "neighbor 2.0 bin",
                  "min_style cg",
                  "minimize 1e-6 1e-6 10000 10000",
                 ]
    para = Lattice.from_para(4.45, 7.70, 7.28, 90, 90, 90)
    strucs = []
    for i in range(10):
        crystal = molecular_crystal(36, ['H2O'], [2], 1.0, lattice=para)
        struc = Atoms(crystal.spg_struct[2],
                    cell=crystal.spg_struct[0],
                    scaled_positions=crystal.spg_struct[1],
                   )
        struc.write('0.vasp', format='vasp', vasp5=True)
        # Named `calc` so the imported `lammps` class is not rebound
        calc = LAMMPSlib(lmp=lmp, lmpcmds=parameters, mol=True)
        struc.set_calculator(calc)
        box = mushybox(struc)
        dyn = FIRE(box)
        dyn.run(fmax=0.01, steps=500)
        dyn = BFGS(box)
        dyn.run(fmax=0.01, steps=500)
        # Largest per-atom force magnitude; a plain max over signed
        # components would miss a dominant negative component
        fmax = np.linalg.norm(struc.get_forces(), axis=1).max()
        eng = struc.get_potential_energy()*96/len(struc)*3
        vol = np.linalg.det(struc.cell)
        struc = struc.repeat((2,2,2))
        struc=sort(struc)
        spg = get_symmetry_dataset(struc, symprec=1e-1)['number']
        struc.write(str(i)+'out.vasp', format='vasp', vasp5=True)
        logging.info('{:d} Spg: {:4d} Eng: {:8.4f} Vol: {:8.4f} fmax: {:4.2f}'.format(i, spg, eng, vol, fmax))
